from distributed_pcg.utils import read_dataset
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sys
import cvxpy as cp
import pickle



def create_data(n,d,l,seed,a):
    """Build a synthetic matrix with a geometric spectrum and its effective dimension.

    A Gaussian ``n x d`` matrix is drawn only to obtain orthogonal factors
    (via the SVD of its Gram matrix); the returned ``A`` is the ``d x d``
    product ``U @ sqrt(diag(a**0, ..., a**(d-1))) @ Vt``.

    Args:
        n: Number of Gaussian rows drawn; also the normaliser in ``A.T @ A / n``.
        d: Dimension of the problem (``A`` is ``d x d``).
        l: Regularisation strength used in the effective dimension.
        seed: Seed passed to ``np.random.seed`` (reseeds the global NumPy RNG).
        a: Base of the geometric sequence placed on the diagonal.

    Returns:
        ``(A, eff_dim)`` where ``eff_dim = tr(G @ inv(G + l * I))`` and
        ``G = A.T @ A / n``.
    """
    np.random.seed(seed)
    gaussian = np.random.normal(size=(n, d))

    # Only the orthogonal factors of the Gram matrix are kept; its singular
    # values are replaced by the prescribed geometric sequence.
    left, _, right_t = np.linalg.svd(gaussian.T @ gaussian)
    sqrt_spectrum = np.diag([a ** k for k in range(d)]) ** 0.5
    A = left @ sqrt_spectrum @ right_t

    scaled_gram = (A.T @ A) / n
    regularised_inv = np.linalg.inv(scaled_gram + l * np.eye(d))
    eff_dim = np.trace(scaled_gram @ regularised_inv)

    return A, eff_dim
    

def _sketched_trace(sketched_hessian, lam):
    """Return ``tr((S H S^T + lam * I)^{-1}) / m`` for an ``m x m`` sketched Hessian."""
    sketch_dim = sketched_hessian.shape[0]
    return np.trace(np.linalg.inv(sketched_hessian + lam * np.eye(sketch_dim))) / sketch_dim


def _solve_debiased_lambda(sketched_hessian, l, tol=1e-3, max_iter=200):
    """Bisect on ``[5*l/12, l]`` for ``lam`` with ``_sketched_trace(lam) == 1/l``.

    Returns ``None`` when the interval does not bracket a solution.
    """
    import warnings

    target = 1 / l
    lo, hi = 5 * l / 12, l

    # Explicit bracket check instead of assert-in-try: asserts vanish under
    # ``python -O`` and a bare ``except`` would also hide KeyboardInterrupt.
    try:
        bracketed = (_sketched_trace(sketched_hessian, lo) >= target
                     and _sketched_trace(sketched_hessian, hi) <= target)
    except np.linalg.LinAlgError:
        bracketed = False
    if not bracketed:
        return None

    # The trace is decreasing in lam, so a value above the target means the
    # solution lies to the right of the midpoint.
    for _ in range(max_iter):
        mid = (lo + hi) / 2
        value = _sketched_trace(sketched_hessian, mid)
        if abs(value - target) <= tol:
            return mid
        if value > target:
            lo = mid
        else:
            hi = mid

    warnings.warn(
        f"bisection for the debiased regulariser did not reach tol={tol} "
        f"within {max_iter} iterations; using the last midpoint"
    )
    return (lo + hi) / 2


def _average_sketched_inverses(H, m, lambdas, n_trials):
    """Average ``S^T (S H S^T + lam I)^{-1} S`` over fresh Gaussian sketches.

    One sketch is drawn per trial and shared by every ``lam`` in ``lambdas``;
    the result is a list of ``d x d`` averages in the same order.
    """
    d = H.shape[0]
    eye_m = np.eye(m)
    totals = [np.zeros((d, d)) for _ in lambdas]
    for _ in range(n_trials):
        S = np.random.normal(loc=0, scale=1 / (m ** 0.5), size=(m, d))
        sketched = S @ H @ S.T
        for k, lam in enumerate(lambdas):
            totals[k] = totals[k] + S.T @ np.linalg.inv(sketched + lam * eye_m) @ S
    return [total / n_trials for total in totals]


def compute_inverse(A,dl,m_list,l):
    """Measure the bias of sketched regularised inverses for several sketch sizes.

    With ``H = A.T @ A / n`` and ``true = inv(H + l * I)``, each sketch size
    ``m`` in ``m_list`` is handled as follows:

    1. Draw one Gaussian sketch ``S`` (``m x d``, entries ``N(0, 1/m)``) and
       bisect for ``hat_l`` in ``[5*l/12, l]`` such that
       ``tr((S H S^T + hat_l I)^{-1}) / m`` is within ``1e-3`` of ``1/l``.
    2. Average ``S^T (S H S^T + lam I)^{-1} S`` over 500 fresh sketches for
       ``lam = l`` (original) and ``lam = hat_l`` (ours).
    3. Record the squared Frobenius distance to ``true`` divided by ``d**2``.

    If the interval does not bracket a solution, only the original estimator
    is evaluated and ``-1`` is recorded for ours.

    Args:
        A: 2-D array whose Gram matrix defines ``H``.
        dl: Unused; kept so existing callers keep working.
        m_list: Iterable of sketch sizes.
        l: Regularisation strength (must be non-zero).

    Returns:
        ``(original_list, ours_list)``, one entry per sketch size.
    """
    n_trials = 500
    n, d = A.shape
    # H and the exact inverse do not depend on the sketch size: compute once.
    H = A.T @ A / n
    true = np.linalg.inv(H + l * np.eye(d))

    original_list = []
    ours_list = []
    for m in m_list:
        print(f'{m=}')
        S = np.random.normal(loc=0, scale=1 / (m ** 0.5), size=(m, d))

        # compute \hat \lambda
        hat_l = _solve_debiased_lambda(S @ H @ S.T, l)

        if hat_l is None:
            # No debiased regulariser in the bracket: report the plain estimator only.
            (no_debias,) = _average_sketched_inverses(H, m, [l], n_trials)
            original_list.append(((true - no_debias) ** 2).sum() / d ** 2)
            ours_list.append(-1)
            continue

        # check for inverse error
        no_debias, ours = _average_sketched_inverses(H, m, [l, hat_l], n_trials)
        original_list.append(((true - no_debias) ** 2).sum() / (d ** 2))
        ours_list.append(((true - ours) ** 2).sum() / (d ** 2))
    return original_list, ours_list



def get_x_axis(data_set):
    """Return an evenly spaced x-grid from 0 to the largest final x in ``data_set``.

    Each element of ``data_set`` is a sequence of points whose first
    coordinate is x. The series whose last point has the largest x sets both
    the upper bound and the point count (the first such series wins on ties).
    A two-point series yields a two-point grid; otherwise the grid has at
    least 100 points.
    """
    best_x, n_points = -1, 0
    for series in data_set:
        last_x = series[-1][0]
        if last_x > best_x:
            best_x, n_points = last_x, len(series)

    samples = 2 if n_points == 2 else max(n_points, 100)
    return np.linspace(0, best_x, samples)

def interpolate(data):
    """Summarise repeated runs by their per-column median and 20%/80% quantiles.

    Args:
        data: Non-empty sequence of equally long 1-D numeric sequences, one
            per run.

    Returns:
        Tuple ``(median, lower, upper)`` of 1-D arrays holding the 0.5, 0.2
        and 0.8 quantiles taken across runs (axis 0).

    Raises:
        IndexError: If ``data`` is empty.
        ValueError: If a row cannot be placed into a row of ``len(data[0])``.
    """
    # Stack the runs once and reuse the array; previously it was built and
    # discarded while np.quantile re-converted the raw list for every call.
    numbers = np.zeros((len(data), len(data[0])))
    for row_index, row in enumerate(data):
        numbers[row_index] = np.array(row)

    median, error_l, error_u = np.quantile(numbers, [0.5, 0.2, 0.8], axis=0)
    return (median, error_l, error_u)


def plot_multi_realdata():
    """Run the bias-vs-sketch-size experiment and save it as ``bias_vs_m.pdf``.

    Ten synthetic problems (seeds 0..9) are generated with ``create_data``
    and evaluated with ``compute_inverse`` for every sketch size. The first
    four sketch sizes are dropped before plotting. For each estimator the
    median across problems is drawn with a 20%-80% quantile band.
    """
    n = 10000
    d = 1000
    m_list2 = [5, 10, 20, 50, 100, 200, 300, 500]
    l = 1e-5  # l=1e-3
    skip = 4  # number of leading (smallest) sketch sizes left out of the plot

    dl_list = []
    dl_b_list = []
    for i in range(10):
        print(i)
        # NOTE: the decay rate is drawn from the global RNG before
        # create_data reseeds it with ``seed``.
        A, eff_dm = create_data(n, d, l, seed=i, a=0.9 + np.random.rand() * 1e-2)
        original, ours = compute_inverse(A, eff_dm, m_list2, l)
        dl_list.append(np.array(original))
        dl_b_list.append(np.array(ours))

    # NOTE: compute_inverse records -1 for "ours" when no debiased regulariser
    # is found; such entries are not filtered here and enter the quantiles.
    plotted_m = np.array(m_list2[skip:])
    dl_plot_data = interpolate([run[skip:] for run in dl_list])
    dl_b_plot_data = interpolate([run[skip:] for run in dl_b_list])

    # A single figure: the former extra plt.figure(figsize=(100, 100)) was
    # never drawn on or saved, it only allocated a huge unused canvas.
    fig, ax = plt.subplots()
    clrs = sns.color_palette("husl", 10)
    ax.plot(plotted_m, dl_plot_data[0], label=r'$\frac{1}{d^2}\|H(\lambda)^{-1}-H_S(\lambda)^{-1}\|_F^2$ (original)', c=clrs[3])
    ax.fill_between(plotted_m, dl_plot_data[1], dl_plot_data[2], alpha=0.3, facecolor=clrs[3])
    ax.plot(plotted_m, dl_b_plot_data[0], label=r'$\frac{1}{d^2}\|H(\lambda)^{-1}-H_S(\hat \lambda)^{-1}\|_F^2$ (ours)', c=clrs[1])
    ax.fill_between(plotted_m, dl_b_plot_data[1], dl_b_plot_data[2], alpha=0.3, facecolor=clrs[1])
    ax.legend(fontsize=18, loc="upper right")
    ax.ticklabel_format(axis='y', style='sci', scilimits=(3, 3))
    ax.set_xlabel('m', fontsize=20)
    ax.set_ylabel('bias', fontsize=20)
    plt.xticks(fontsize=17)
    plt.yticks(fontsize=17)
    fig.tight_layout()
    fig.savefig('bias_vs_m.pdf')




# Script entry point: when this file is executed directly, run the full
# experiment and write the resulting figure to disk.
if __name__ == '__main__':
    plot_multi_realdata()

